import numpy as np
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from utils.typing import *

from .attention import MemEffAttention


class VolumeAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        groups: int = 32,
        eps: float = 1e-5,
        residual: bool = True,
        skip_scale: float = 1,
    ):
        super().__init__()
        self.residual = residual
        self.skip_scale = skip_scale
        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)
        self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)

    def forward(self, x):
        # x: [B, C, H, W, D]
        B, C, H, W, D = x.shape

        res = x
        x = self.norm(x)

        # flatten the spatial volume into a token sequence, attend, then restore the volume layout
        x = x.permute(0, 2, 3, 4, 1).reshape(B, -1, C)
        x = self.attn(x)
        x = x.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3).reshape(B, C, H, W, D)

        if self.residual:
            x = (x + res) * self.skip_scale

        return x


class DiagonalGaussianDistribution:
    def __init__(self, parameters, deterministic=False):
        # parameters: [B, 2C, ...]
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self):
        sample = torch.randn(self.mean.shape, device=self.parameters.device, dtype=self.parameters.dtype)
        x = self.mean + self.std * sample
        return x

    def kl(self, other=None, dims=[1, 2, 3, 4]):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims)
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=dims,
                )
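
    # For reference (standard diagonal-Gaussian result, matching kl() above):
    #   KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - 1 - log(sigma^2))
    # summed over latent dimensions; kl() averages over `dims` instead of summing,
    # so it returns a per-sample mean rather than a total.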

    def nll(self, sample, dims=[1, 2, 3, 4]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        return self.mean


class ResnetBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resample: Literal['default', 'up', 'down'] = 'default',
        groups: int = 32,
        eps: float = 1e-5,
        skip_scale: float = 1,  # multiplied to output
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_scale = skip_scale

        self.norm1 = nn.GroupNorm(num_groups=min(groups, in_channels), num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=min(groups, out_channels), num_channels=out_channels, eps=eps, affine=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.act = F.silu

        self.resample = None
        if resample == 'up':
            self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
        elif resample == 'down':
            self.resample = nn.AvgPool3d(kernel_size=2, stride=2)

        self.shortcut = nn.Identity()
        if self.in_channels != self.out_channels:
            self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=True)

    def forward(self, x):
        res = x

        x = self.norm1(x)
        x = self.act(x)

        if self.resample:
            res = self.resample(res)
            x = self.resample(x)

        x = self.conv1(x)
        x = self.norm2(x)
        x = self.act(x)
        x = self.conv2(x)

        x = (x + self.shortcut(res)) * self.skip_scale

        return x


class DownBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        downsample: bool = True,
        skip_scale: float = 1,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.gradient_checkpointing = gradient_checkpointing

        nets = []
        for i in range(num_layers):
            cin = in_channels if i == 0 else out_channels
            nets.append(ResnetBlock(cin, out_channels, skip_scale=skip_scale))
        self.nets = nn.ModuleList(nets)

        self.downsample = None
        if downsample:
            self.downsample = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)

    def _forward(self, x):
        for net in self.nets:
            x = net(x)
        if self.downsample:
            x = self.downsample(x)
        return x


class MidBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        attention: bool = True,
        attention_heads: int = 8,
        skip_scale: float = 1,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.gradient_checkpointing = gradient_checkpointing

        nets = []
        attns = []
        # first layer
        nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
        # more layers
        for i in range(num_layers):
            nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
            if attention:
                attns.append(VolumeAttention(in_channels, attention_heads, skip_scale=skip_scale))
            else:
                attns.append(None)
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

    def forward(self, x):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)

    def _forward(self, x):
        x = self.nets[0](x)
        for attn, net in zip(self.attns, self.nets[1:]):
            if attn:
                x = attn(x)
            x = net(x)
        return x


class UpBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        upsample: bool = True,
        skip_scale: float = 1,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.gradient_checkpointing = gradient_checkpointing

        nets = []
        for i in range(num_layers):
            cin = in_channels if i == 0 else out_channels
            nets.append(ResnetBlock(cin, out_channels, skip_scale=skip_scale))
        self.nets = nn.ModuleList(nets)

        self.upsample = None
        if upsample:
            self.upsample = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)

    def _forward(self, x):
        for net in self.nets:
            x = net(x)
        if self.upsample:
            x = self.upsample(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 2 * 16,  # double_z
        down_channels: Tuple[int, ...] = (8, 16, 32, 64),
        mid_attention: bool = True,
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
        gradient_checkpointing: bool = False,
    ):
        super().__init__()

        # first
        self.conv_in = nn.Conv3d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)

        # down
        down_blocks = []
        cout = down_channels[0]
        for i in range(len(down_channels)):
            cin = cout
            cout = down_channels[i]
            down_blocks.append(DownBlock(
                cin, cout,
                num_layers=layers_per_block,
                downsample=(i != len(down_channels) - 1),  # no downsample in the last block
                skip_scale=skip_scale,
                gradient_checkpointing=gradient_checkpointing,
            ))
        self.down_blocks = nn.ModuleList(down_blocks)

        # mid
        self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale)

        # last
        self.norm_out = nn.GroupNorm(num_channels=down_channels[-1], num_groups=32, eps=1e-5)
        self.conv_out = nn.Conv3d(down_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # x: [B, Cin, H, W, D]

        # first
        x = self.conv_in(x)

        # down
        for block in self.down_blocks:
            x = block(x)

        # mid
        x = self.mid_block(x)

        # last
        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)

        return x


class Decoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 1,
        up_channels: Tuple[int, ...] = (64, 32, 16, 8),
        mid_attention: bool = True,
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
        gradient_checkpointing: bool = False,
    ):
        super().__init__()

        # first
        self.conv_in = nn.Conv3d(in_channels, up_channels[0], kernel_size=3, stride=1, padding=1)

        # mid
        self.mid_block = MidBlock(up_channels[0], attention=mid_attention, skip_scale=skip_scale)

        # up
        up_blocks = []
        cout = up_channels[0]
        for i in range(len(up_channels)):
            cin = cout
            cout = up_channels[i]
            up_blocks.append(UpBlock(
                cin, cout,
                num_layers=layers_per_block,
                upsample=(i != len(up_channels) - 1),  # no upsample in the last block
                skip_scale=skip_scale,
                gradient_checkpointing=gradient_checkpointing,
            ))
        self.up_blocks = nn.ModuleList(up_blocks)

        # last (kernel_size=3, stride=1, padding=1 keeps the spatial size)
        self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=min(32, up_channels[-1]), eps=1e-5)
        self.conv_out = nn.ConvTranspose3d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # x: [B, Cin, H, W, D]

        # first
        x = self.conv_in(x)

        # mid
        x = self.mid_block(x)

        # up
        for block in self.up_blocks:
            x = block(x)

        # last
        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)

        return x


class VAE(nn.Module):
    def __init__(
        self,
        in_channels: int = 1,
        latent_channels: int = 16,
        out_channels: int = 1,
        down_channels: Tuple[int, ...] = (16, 32, 64, 128, 256),
        mid_attention: bool = True,
        up_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
        gradient_checkpointing: bool = False,
    ):
        super().__init__()

        # encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=2 * latent_channels,  # double_z
            down_channels=down_channels,
            mid_attention=mid_attention,
            layers_per_block=layers_per_block,
            skip_scale=skip_scale,
            gradient_checkpointing=gradient_checkpointing,
        )

        # decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_channels=up_channels,
            mid_attention=mid_attention,
            layers_per_block=layers_per_block,
            skip_scale=skip_scale,
            gradient_checkpointing=gradient_checkpointing,
        )

        # quant
        self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1)

    def encode(self, x):
        x = self.encoder(x)
        x = self.quant_conv(x)
        posterior = DiagonalGaussianDistribution(x)
        return posterior

    def decode(self, x):
        x = self.post_quant_conv(x)
        x = self.decoder(x)
        return x

    def forward(self, x, sample=True):
        # x: [B, Cin, H, W, D]
        p = self.encode(x)
        if sample:
            z = p.sample()
        else:
            z = p.mode()
        x = self.decode(z)
        return x, p
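

if __name__ == "__main__":
    # Minimal smoke test; a sketch, not part of the original module. The 64^3 single-channel
    # input is an assumed shape chosen so the four stride-2 downsamples and the GroupNorm
    # group counts divide evenly; MemEffAttention must be importable from this repo.
    vae = VAE()
    x = torch.randn(1, 1, 64, 64, 64)  # [B, Cin, H, W, D]
    recon, posterior = vae(x)  # sample=True (default) draws z from the posterior
    print(recon.shape, posterior.kl().shape)  # expect [1, 1, 64, 64, 64] and [1]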