| """HSIGene VAE blocks - ResnetBlock, Encoder, Decoder, DiagonalGaussianDistribution.""" | |
| from typing import Optional, Any | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| try: | |
| import xformers | |
| import xformers.ops | |
| XFORMERS_IS_AVAILABLE = True | |
| except ImportError: | |
| XFORMERS_IS_AVAILABLE = False | |


def nonlinearity(x):
    """Swish activation, x * sigmoid(x) (equivalent to F.silu)."""
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class ResnetBlock(nn.Module):
    """Pre-activation residual block with optional timestep embedding and a
    1x1 (or 3x3) shortcut when the channel count changes."""

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h
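

# Shape sketch (hypothetical values, not from the HSIGene config): a block maps
# (B, C_in, H, W) -> (B, C_out, H, W); e.g.
#   blk = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=0)
#   blk(torch.randn(1, 64, 32, 32), temb=None)  # -> (1, 128, 32, 32)
# Spatial size is preserved; only the channel count may change via the shortcut.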


class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # Attend over all h*w positions: w_[b, i, j] = softmax_j(q_i . k_j / sqrt(c)).
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k) * (int(c) ** -0.5)
        w_ = F.softmax(w_, dim=2)
        v = v.reshape(b, c, h * w)
        h_ = torch.bmm(v, w_.permute(0, 2, 1))
        h_ = h_.reshape(b, c, h, w)
        h_ = self.proj_out(h_)
        return x + h_
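

# The bmm/softmax pipeline above is standard scaled dot-product attention; on
# PyTorch >= 2.0 the same result could be computed as
#   F.scaled_dot_product_attention(q, k, v)  # q, k, v shaped (b, h*w, c)
# which is also what the xformers variant below computes with lower peak memory.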


class MemoryEfficientAttnBlock(nn.Module):
    """AttnBlock using xformers memory-efficient attention when available."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        B, C, H, W = q.shape
        q, k, v = map(lambda t: rearrange(t, "b c h w -> b (h w) c"), (q, k, v))
        # Lay the tensors out as a single attention head of width C, matching the
        # (batch * heads, seq_len, head_dim) layout xformers expects.
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(B, t.shape[1], 1, C)
            .permute(0, 2, 1, 3)
            .reshape(B * 1, t.shape[1], C)
            .contiguous(),
            (q, k, v),
        )
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )
        out = (
            out.unsqueeze(0)
            .reshape(B, 1, out.shape[1], C)
            .permute(0, 2, 1, 3)
            .reshape(B, out.shape[1], C)
        )
        out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
        out = self.proj_out(out)
        return x + out
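

# Both attention blocks expose the same parameter names and shapes (norm, q, k,
# v, proj_out), so a checkpoint trained with one variant should load into the
# other; make_attn below picks between them purely by xformers availability.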


def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    assert attn_type in ["vanilla", "vanilla-xformers", "none"]
    # Prefer the memory-efficient implementation whenever xformers is installed.
    if XFORMERS_IS_AVAILABLE and attn_type == "vanilla":
        attn_type = "vanilla-xformers"
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "vanilla-xformers":
        return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity()
    raise NotImplementedError(f"attn_type {attn_type}")
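

# Usage sketch: make_attn(512) returns MemoryEfficientAttnBlock(512) when
# xformers is importable and AttnBlock(512) otherwise; make_attn(512, "none")
# returns nn.Identity(), disabling attention at that resolution.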


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # No built-in padding; asymmetric padding is applied in forward.
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)  # pad right and bottom only
            x = F.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
        return x
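

# Worked example of the asymmetric padding (illustrative numbers): for an even
# input size H, padding right/bottom by one gives floor((H + 1 - 3) / 2) + 1 = H / 2,
# so a 64x64 map becomes 32x32, matching the avg_pool2d branch exactly.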


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x
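

# Nearest-neighbour 2x upsampling followed by a 3x3 convolution is the common
# alternative to transposed convolution, which tends to avoid checkerboard
# artifacts in the decoded output.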


class Encoder(nn.Module):
    """Downsampling encoder: conv_in -> len(ch_mult) resolution levels of
    ResnetBlocks (with attention at the listed resolutions) -> mid blocks ->
    conv_out, producing 2 * z_channels moment channels when double_z is set."""

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            # Note: "linear" is not among the types make_attn accepts in this
            # module, so enabling use_linear_attn will trip its assertion.
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(
            block_in, 2 * z_channels if double_z else z_channels,
            kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        temb = None
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
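

# Usage sketch (assumed, LDM-style hyperparameters, not the HSIGene config): an
# encoder built with
#   Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
#           attn_resolutions=[], in_channels=3, resolution=64, z_channels=4)
# maps (B, 3, 64, 64) inputs to (B, 2 * 4, 16, 16) moment tensors: the spatial
# size halves len(ch_mult) - 1 = 2 times, and double_z doubles the channels.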


class Decoder(nn.Module):
    """Mirror of the Encoder: conv_in -> mid blocks -> len(ch_mult) resolution
    levels of ResnetBlocks with upsampling -> conv_out back to out_ch channels."""

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            # See the note in Encoder: "linear" is rejected by make_attn here.
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        in_ch_mult = (1,) + tuple(ch_mult)
        # Start at the lowest resolution with the widest channel count.
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend so self.up stays ordered by i_level
        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        self.last_z_shape = z.shape
        temb = None
        h = self.conv_in(z)
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        if self.give_pre_end:
            return h
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
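

# Usage sketch (assumed hyperparameters matching the Encoder example above): a
# decoder built with ch=128, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
# attn_resolutions=[], in_channels=3, resolution=64, z_channels=4 maps a
# (B, 4, 16, 16) latent back to a (B, 3, 64, 64) reconstruction.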


class DiagonalGaussianDistribution:
    """Diagonal Gaussian over latents, parameterized by concatenated
    [mean, logvar] along the channel dimension (the encoder's double_z output)."""

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -20.0, 0.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        # Reparameterization: z = mean + std * eps, eps ~ N(0, I).
        x = self.mean + self.std * torch.randn(
            self.mean.shape, device=self.parameters.device
        )
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.tensor(0.0, device=self.parameters.device)
        if other is None:
            # KL(N(mean, var) || N(0, I)), summed over all non-batch dims.
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3],
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample, dims=(1, 2, 3)):
        if self.deterministic:
            return torch.tensor(0.0, device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean
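

if __name__ == "__main__":
    # Minimal smoke test with small, assumed hyperparameters (not the HSIGene
    # training configuration): encode an RGB image, sample the posterior, decode.
    # attn_type="none" keeps the test device-agnostic (no xformers required).
    enc = Encoder(
        ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[],
        in_channels=3, resolution=16, z_channels=4, attn_type="none",
    )
    dec = Decoder(
        ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[],
        in_channels=3, resolution=16, z_channels=4, attn_type="none",
    )
    x = torch.randn(1, 3, 16, 16)
    moments = enc(x)  # (1, 8, 8, 8): [mean, logvar] stacked on dim 1
    posterior = DiagonalGaussianDistribution(moments)
    z = posterior.sample()  # (1, 4, 8, 8)
    recon = dec(z)  # (1, 3, 16, 16)
    print(moments.shape, z.shape, recon.shape, posterior.kl().shape)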