| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| |
|
| | from diffusers import DiffusionPipeline |
| | import tqdm |
| | import torch |
| |
|
| | |
| | import math |
| |
|
| | import numpy as np |
| | import torch.nn as nn |
| |
|
| | from diffusers.configuration_utils import ConfigMixin |
| | from diffusers.modeling_utils import ModelMixin |
| |
|
| |
|
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    The first `embedding_dim // 2` channels hold sines and the next `embedding_dim // 2` hold cosines of the
    timestep multiplied by geometrically spaced frequencies (from 1 down to 1/10000). When `embedding_dim` is
    odd, one trailing zero channel is appended. The original note on this function says the layout follows the
    DDPM / Fairseq / tensor2tensor implementations and differs slightly from Section 3.5 of
    "Attention Is All You Need".

    Args:
        timesteps: 1-D tensor holding N timesteps.
        embedding_dim: number of output channels.

    Returns:
        A float32 tensor of shape (N, embedding_dim) on the device of `timesteps`.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim of 2 or 3) `half_dim - 1` is 0; clamp the divisor so the
    # function returns sin(t) / cos(t) instead of raising ZeroDivisionError.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_step)
    freqs = freqs.to(device=timesteps.device)
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:
        # zero-pad the last channel so the width is exactly `embedding_dim`
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
| |
|
| |
|
def nonlinearity(x):
    """Swish activation: multiply the input element-wise by its own sigmoid."""
    gate = torch.sigmoid(x)
    return gate * x
| |
|
| |
|
def Normalize(in_channels):
    """Return the normalisation layer used by every block in this file: GroupNorm with 32 groups."""
    norm_kwargs = {"num_groups": 32, "num_channels": in_channels, "eps": 1e-6, "affine": True}
    return torch.nn.GroupNorm(**norm_kwargs)
| |
|
| |
|
class Upsample(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation, optionally followed by a 3x3 convolution."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # channel-preserving, size-preserving convolution applied after the interpolation
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if not self.with_conv:
            return upsampled
        return self.conv(upsampled)
| |
|
| |
|
class Downsample(nn.Module):
    """Halve the spatial size, either with a strided 3x3 convolution or with 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # no padding here: forward() pads one zero row/column on the bottom/right by hand
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # (left, right, top, bottom) = (0, 1, 0, 1): asymmetric zero padding before the strided conv
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
| |
|
| |
|
class ResnetBlock(nn.Module):
    """
    Residual block: (norm -> swish -> 3x3 conv) twice, with an optional additive timestep-embedding projection
    between the two halves and dropout before the second convolution.

    When `in_channels` differs from `out_channels` the skip path is projected with a 3x3 convolution
    (`conv_shortcut=True`) or a 1x1 convolution (default). `temb_proj` is only created when
    `temb_channels > 0`.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if in_channels != out_channels:
            if conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        hidden = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            # project the embedding to `out_channels` and broadcast it over the two spatial axes
            hidden = hidden + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        hidden = self.conv2(self.dropout(nonlinearity(self.norm2(hidden))))

        if self.in_channels == self.out_channels:
            return x + hidden

        shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
        return shortcut(x) + hidden
| |
|
| |
|
class AttnBlock(nn.Module):
    """
    Single-head self-attention over the spatial positions of a feature map, added back to the input.

    Queries, keys, values and the output projection are all 1x1 convolutions that keep `in_channels`.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        # attention scores: flatten the spatial grid, then scores[b, i, j] = sum_c query[b, i, c] * key[b, c, j]
        batch, channels, height, width = query.shape
        n_pos = height * width
        query = query.reshape(batch, channels, n_pos).permute(0, 2, 1)  # (b, hw, c)
        key = key.reshape(batch, channels, n_pos)  # (b, c, hw)
        scores = torch.bmm(query, key)  # (b, hw, hw)
        scores = scores * (int(channels) ** (-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # weighted sum of values: out[b, c, i] = sum_j value[b, c, j] * weights[b, i, j]
        value = value.reshape(batch, channels, n_pos)
        weights = weights.permute(0, 2, 1)  # (b, hw of key, hw of query)
        attended = torch.bmm(value, weights).reshape(batch, channels, height, width)

        return x + self.proj_out(attended)
| |
|
| |
|
class Model(nn.Module):
    """
    U-Net built from `ResnetBlock` / `AttnBlock` stages with skip connections between the down and up paths.

    The down path has one stage per entry of `ch_mult` (`num_res_blocks` residual blocks each, attention after
    every block when the current resolution is listed in `attn_resolutions`, and a 2x downsample after every
    stage but the last). The middle is ResNet -> attention -> ResNet. The up path mirrors the down path with
    `num_res_blocks + 1` blocks per stage, each consuming one stored skip activation. When `use_timestep` is
    true, a two-layer MLP on the sinusoidal timestep embedding conditions every residual block.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.use_timestep = use_timestep

        def make_resblock(n_in, n_out):
            return ResnetBlock(in_channels=n_in, out_channels=n_out, temb_channels=self.temb_ch, dropout=dropout)

        if self.use_timestep:
            # timestep embedding MLP: ch -> 4 * ch -> 4 * ch
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    nn.Linear(self.ch, self.temb_ch),
                    nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # input projection and downsampling path
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level, (mult_in, mult_out) in enumerate(zip(in_ch_mult, ch_mult)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_in = ch * mult_in
            block_out = ch * mult_out
            for _ in range(self.num_res_blocks):
                res_blocks.append(make_resblock(block_in, block_out))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(AttnBlock(block_in))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != self.num_resolutions - 1:
                stage.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(stage)

        # middle section at the lowest resolution
        self.mid = nn.Module()
        self.mid.block_1 = make_resblock(block_in, block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = make_resblock(block_in, block_in)

        # upsampling path, built from the lowest resolution upwards but stored in level order
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                # the final block of a stage receives the skip stored before the matching down stage,
                # whose width follows `in_ch_mult` rather than `ch_mult`
                is_last_block = i_block == self.num_res_blocks
                skip_in = ch * (in_ch_mult[i_level] if is_last_block else ch_mult[i_level])
                res_blocks.append(make_resblock(block_in + skip_in, block_out))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(AttnBlock(block_in))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, stage)  # prepend to get consistent order

        # output head
        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t=None):
        # timestep embedding
        if self.use_timestep:
            assert t is not None
            dense_in, dense_out = self.temb.dense
            temb = dense_out(nonlinearity(dense_in(get_timestep_embedding(t, self.ch))))
        else:
            temb = None

        # downsampling: every intermediate activation is kept for the skip connections
        skips = [self.conv_in(x)]
        last_level = self.num_resolutions - 1
        for i_level, stage in enumerate(self.down):
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](skips[-1], temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
                skips.append(h)
            if i_level != last_level:
                skips.append(stage.downsample(skips[-1]))

        # middle
        h = skips[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling: each block consumes the most recent stored activation
        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = stage.block[i_block](torch.cat([h, skips.pop()], dim=1), temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            if i_level != 0:
                h = stage.upsample(h)

        # output head
        return self.conv_out(nonlinearity(self.norm_out(h)))
| |
|
| |
|
class Encoder(nn.Module):
    """
    Convolutional encoder: residual stages with optional attention and 2x downsampling between stages, a
    ResNet -> attention -> ResNet middle section, and a final 3x3 convolution producing `z_channels` channels
    (`2 * z_channels` when `double_z` is true).

    No timestep conditioning is used (`temb_ch` is 0). `out_ch` is accepted but not used by this class, and
    any extra keyword arguments are ignored.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        **ignore_kwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        def make_resblock(n_in, n_out):
            return ResnetBlock(in_channels=n_in, out_channels=n_out, temb_channels=self.temb_ch, dropout=dropout)

        # input projection and downsampling path
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level, (mult_in, mult_out) in enumerate(zip(in_ch_mult, ch_mult)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_in = ch * mult_in
            block_out = ch * mult_out
            for _ in range(self.num_res_blocks):
                res_blocks.append(make_resblock(block_in, block_out))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(AttnBlock(block_in))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != self.num_resolutions - 1:
                stage.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(stage)

        # middle section at the lowest resolution
        self.mid = nn.Module()
        self.mid.block_1 = make_resblock(block_in, block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = make_resblock(block_in, block_in)

        # output head
        latent_channels = 2 * z_channels if double_z else z_channels
        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in, latent_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # the residual blocks are built without a timestep projection, so no embedding is passed
        temb = None

        # downsampling; activations are collected in a list and only the latest one is read back
        feats = [self.conv_in(x)]
        last_level = self.num_resolutions - 1
        for i_level, stage in enumerate(self.down):
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](feats[-1], temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
                feats.append(h)
            if i_level != last_level:
                feats.append(stage.downsample(feats[-1]))

        # middle
        h = feats[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # output head
        return self.conv_out(nonlinearity(self.norm_out(h)))
| |
|
| |
|
class Decoder(nn.Module):
    """
    Convolutional decoder: a 3x3 convolution from `z_channels` to the widest feature size, a
    ResNet -> attention -> ResNet middle section, then one stage per entry of `ch_mult` (from the last entry
    to the first) with `num_res_blocks + 1` residual blocks, optional attention, and a 2x upsample after
    every stage except the one for level 0.

    With `give_pre_end=True`, `forward` returns the features before the final norm / swish / convolution.
    No timestep conditioning is used (`temb_ch` is 0); extra keyword arguments are ignored.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        **ignorekwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        def make_resblock(n_in, n_out):
            return ResnetBlock(in_channels=n_in, out_channels=n_out, temb_channels=self.temb_ch, dropout=dropout)

        # width and resolution at the lowest level, where decoding starts
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # latent -> features
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle section
        self.mid = nn.Module()
        self.mid.block_1 = make_resblock(block_in, block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = make_resblock(block_in, block_in)

        # upsampling path, built from the lowest resolution upwards but stored in level order
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(make_resblock(block_in, block_out))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(AttnBlock(block_in))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, stage)  # prepend to get consistent order

        # output head
        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        # remember the latent shape of the most recent call
        self.last_z_shape = z.shape

        # the residual blocks are built without a timestep projection, so no embedding is passed
        temb = None

        # latent -> features, then the middle section
        h = self.conv_in(z)
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = stage.block[i_block](h, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            if i_level != 0:
                h = stage.upsample(h)

        if self.give_pre_end:
            return h

        # output head
        return self.conv_out(nonlinearity(self.norm_out(h)))
| |
|
| |
|
class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.

    Args:
        n_e: number of codebook entries.
        e_dim: dimensionality of each codebook entry (the channel count of the input to `forward`).
        beta: weight of one of the two loss terms (which one depends on `legacy`).
        remap: optional path passed to `np.load`; the loaded array lists the codebook indices in use.
        unknown_index: what `remap_to_used` assigns to indices missing from `remap`: "random", "extra"
            (one additional index after the used ones), or an explicit integer.
        sane_index_shape: if true, `forward` returns indices shaped (batch, height, width).
        legacy: if true, `beta` weights the `(z_q - z.detach())` term; otherwise the `(z_q.detach() - z)` term.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        """Map full-codebook indices to positions in `self.used`; unknown indices get `unknown_index`."""
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """Inverse of `remap_to_used`: map positions in `self.used` back to full-codebook indices."""
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            # Out-of-place on purpose: `reshape` may return a view of the caller's tensor, and an in-place
            # write here would silently modify the indices the caller passed in.
            inds = inds.masked_fill(inds >= self.used.shape[0], 0)  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        """
        Quantize `z` (batch, channel, height, width) to its nearest codebook entries.

        Returns:
            `(z_q, loss, (perplexity, min_encodings, min_encoding_indices))` where `z_q` has the shape of
            `z` and carries straight-through gradients, and `perplexity` / `min_encodings` are always None.
        """
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert not rescale_logits, "Only for interface compatible with Gumbel"
        assert not return_logits, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten.
        # `permute` replaces einops' `rearrange`, which this file never imports.
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients (straight-through: forward value is z_q, gradient flows to z)
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        """
        Look up codebook vectors for `indices`.

        `shape` is (batch, height, width, channel); when given, the result is reshaped to it and permuted to
        (batch, channel, height, width). With `shape=None` the raw embedding lookup is returned.
        """
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
| |
|
| |
|
class VQModel(ModelMixin, ConfigMixin):
    """
    Vector-quantized autoencoder: `Encoder` -> 1x1 `quant_conv` -> `VectorQuantizer` -> 1x1 `post_quant_conv`
    -> `Decoder`.

    `encode` returns the pre-quantization latents with `embed_dim` channels. `decode` quantizes them (unless
    `force_not_quantize` is set), projects back to `z_channels` channels and runs the decoder. The quantizer
    is built with a fixed `beta` of 0.25.
    """

    def __init__(
        self,
        ch,
        out_ch,
        num_res_blocks,
        attn_resolutions,
        in_channels,
        resolution,
        z_channels,
        n_embed,
        embed_dim,
        remap=None,
        sane_index_shape=False,
        ch_mult=(1, 2, 4, 8),
        dropout=0.0,
        double_z=True,
        resamp_with_conv=True,
        give_pre_end=False,
    ):
        super().__init__()

        # register all __init__ params with self.register
        self.register(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            n_embed=n_embed,
            embed_dim=embed_dim,
            remap=remap,
            sane_index_shape=sane_index_shape,
            ch_mult=ch_mult,
            dropout=dropout,
            double_z=double_z,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )

        # pass init params to Encoder (it ignores `give_pre_end` via its catch-all keyword arguments)
        self.encoder = Encoder(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            ch_mult=ch_mult,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            double_z=double_z,
            give_pre_end=give_pre_end,
        )

        # `encode` / `decode` call these two projections, so they must exist. The encoder emits
        # `2 * z_channels` channels when `double_z` is set and `z_channels` otherwise; the quantizer works on
        # `embed_dim` channels; the decoder's first convolution expects `z_channels` channels.
        encoder_out_channels = 2 * z_channels if double_z else z_channels
        self.quant_conv = torch.nn.Conv2d(encoder_out_channels, embed_dim, 1)

        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)

        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

        # pass init params to Decoder
        self.decoder = Decoder(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            ch_mult=ch_mult,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )

    def encode(self, x):
        """Encode `x` and project the result to `embed_dim` channels; no quantization happens here."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        """Decode latents `h`; they are quantized first unless `force_not_quantize` is true."""
        # also go through quantization layer
        if not force_not_quantize:
            # the quantizer also returns a loss and index info, neither of which is needed for decoding
            quant, _, _ = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec
| |
|
| |
|
class DDPM(DiffusionPipeline):
    """
    Sampling pipeline: start from scheduler noise and step through every scheduler timestep in reverse,
    using `unet` to predict the noise residual at each step.

    NOTE(review): `__call__` reads `self.noise_scheduler`, but `__init__` only registers `unet` and `vqvae`,
    and `vqvae` is not used by `__call__` — confirm where the scheduler is expected to come from.
    NOTE(review): `t` is handed to `unet` as a plain Python int, and `t - 1` is -1 on the final step —
    confirm that the model and scheduler accept both.
    """

    def __init__(self, unet, vqvae):
        super().__init__()
        self.register_modules(unet=unet, vqvae=vqvae)

    def __call__(self, batch_size=1, generator=None, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.unet.to(torch_device)
        scheduler = self.noise_scheduler

        # 1. Sample gaussian noise
        image = scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            device=torch_device,
            generator=generator,
        )

        timesteps = reversed(range(len(scheduler)))
        for t in tqdm.tqdm(timesteps, total=len(scheduler)):
            # i) coefficients that turn (image, predicted noise) into a clamped mean estimate ...
            mean_from_image = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
            mean_from_noise = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
            # ... and coefficients that mix the current image and that estimate into the previous image
            keep_image = (
                (1 - scheduler.get_alpha_prod(t - 1))
                * torch.sqrt(scheduler.get_alpha(t))
                / (1 - scheduler.get_alpha_prod(t))
            )
            keep_mean = (
                torch.sqrt(scheduler.get_alpha_prod(t - 1))
                * scheduler.get_beta(t)
                / (1 - scheduler.get_alpha_prod(t))
            )

            # ii) predict noise residual
            with torch.no_grad():
                noise_residual = self.unet(image, t)

            # iii) mean estimate, clamped to [-1, 1], then the image for the previous timestep
            clamped_mean = torch.clamp(mean_from_image * image - mean_from_noise * noise_residual, -1, 1)
            prev_image = keep_mean * clamped_mean + keep_image * image

            # iv) sample the scheduler's variance term for this step
            prev_variance = scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )

            # v) the noisy previous image becomes the input of the next iteration
            image = prev_image + prev_variance

        return image
| |
|