| import torch |
| import torch.nn as nn |
| import math |
| import torch.nn.functional as F |
| import numpy as np |
| import omegaconf |
|
|
| import transformers |
| from einops import rearrange |
| from .dit import LabelEmbedder, EmbeddingLayer |
|
|
|
|
| |
| |
def transformer_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Embed scalar timesteps with Transformer-style sinusoids.

    Args:
        timesteps: 1-D tensor of shape (N,); it is cast to float32.
        embedding_dim: Size of the output features. When odd, a zero column
            is appended so the output width still equals ``embedding_dim``.
        max_positions: Base that sets the lowest frequency
            (``1 / max_positions``).

    Returns:
        Tensor of shape (N, embedding_dim) holding the sines of all
        frequencies followed by the cosines.
    """
    assert len(timesteps.shape) == 1
    half = embedding_dim // 2

    log_step = math.log(max_positions) / (half - 1)
    # Geometric ladder of frequencies running from 1 down to 1 / max_positions.
    freqs = torch.exp(
        -log_step * torch.arange(half, dtype=torch.float32, device=timesteps.device)
    )

    angles = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat((angles.sin(), angles.cos()), dim=1)
    if embedding_dim % 2:
        # Odd widths get one zero-valued trailing feature.
        emb = F.pad(emb, (0, 1), mode='constant')

    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
|
|
|
|
| |
def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
    """Build a variance-scaling weight initializer (ported from JAX).

    Args:
        scale: Numerator of the target variance.
        mode: One of ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``; selects
            the denominator of the variance.
        distribution: ``"normal"`` or ``"uniform"``.
        in_axis: Axis of the weight shape treated as the input dimension.
        out_axis: Axis of the weight shape treated as the output dimension.
        dtype: Default dtype of tensors produced by the initializer.
        device: Default device of tensors produced by the initializer.

    Returns:
        A function ``init(shape, dtype=dtype, device=device)`` that draws a
        tensor of ``shape``. Invalid ``mode`` / ``distribution`` values are
        only detected, with ``ValueError``, when ``init`` is called.
    """

    def _compute_fans(shape, in_axis=1, out_axis=0):
        # Any axes other than in/out form the receptive field (e.g. kernel H x W).
        receptive = np.prod(shape) / shape[in_axis] / shape[out_axis]
        return shape[in_axis] * receptive, shape[out_axis] * receptive

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)

        candidates = (
            ("fan_in", fan_in),
            ("fan_out", fan_out),
            ("fan_avg", (fan_in + fan_out) / 2),
        )
        for name, value in candidates:
            if mode == name:
                denominator = value
                break
        else:
            raise ValueError(
                "invalid mode for variance scaling initializer: {}".format(mode))

        variance = scale / denominator
        if distribution == "uniform":
            # U(-1, 1) has variance 1/3, hence the factor 3 inside the sqrt.
            noise = torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.
            return noise * np.sqrt(3 * variance)
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        raise ValueError("invalid distribution for variance scaling initializer")

    return init
|
|
|
|
def default_init(scale=1.):
    """Return the DDPM default initializer: fan-average uniform variance scaling.

    A ``scale`` equal to zero is replaced by 1e-10, so the resulting weights
    are tiny but not identically zero.
    """
    if scale == 0:
        scale = 1e-10
    return variance_scaling(scale, 'fan_avg', 'uniform')
|
|
|
|
class NiN(nn.Module):
    """Network-in-network layer: a per-pixel linear map over channels.

    Equivalent to a 1x1 convolution with weight ``W`` of shape
    (in_ch, out_ch) and bias ``b`` of shape (out_ch,).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        init_scale: Scale handed to ``default_init`` for ``W``; ``b`` starts at zero.
    """

    def __init__(self, in_ch, out_ch, init_scale=0.1):
        super().__init__()
        weight_init = default_init(scale=init_scale)
        self.W = nn.Parameter(weight_init((in_ch, out_ch)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(out_ch), requires_grad=True)

    def forward(self, x):
        """Mix channels of a (B, in_ch, H, W) tensor into (B, out_ch, H, W)."""
        channels_last = x.permute(0, 2, 3, 1)
        mixed = torch.einsum('bhwi,ik->bhwk', channels_last, self.W)
        mixed = mixed + self.b
        return mixed.permute(0, 3, 1, 2)
|
|
class AttnBlock(nn.Module):
    """Self-attention across the spatial positions of a (B, C, H, W) map.

    Every pixel attends to every other pixel; channels act as the feature
    dimension. The result is added back to the input, and the sum is divided
    by sqrt(2) when ``skip_rescale`` is set.

    Args:
        channels: Number of feature channels C.
        skip_rescale: Whether to rescale the residual sum by 1 / sqrt(2).
    """

    def __init__(self, channels, skip_rescale=True):
        super().__init__()
        self.skip_rescale = skip_rescale
        groups = min(channels // 4, 32)
        self.GroupNorm_0 = nn.GroupNorm(num_groups=groups, num_channels=channels, eps=1e-6)
        # Query, key and value projections, then the output projection.
        self.NIN_0 = NiN(channels, channels)
        self.NIN_1 = NiN(channels, channels)
        self.NIN_2 = NiN(channels, channels)
        # Output projection starts (near) zero (init_scale=0 -> 1e-10 in default_init).
        self.NIN_3 = NiN(channels, channels, init_scale=0.)

    def forward(self, x):
        """Apply residual spatial self-attention to ``x`` of shape (B, C, H, W)."""
        batch, channels, height, width = x.shape
        normed = self.GroupNorm_0(x)
        query, key, value = self.NIN_0(normed), self.NIN_1(normed), self.NIN_2(normed)

        # scores[b, h, w, i, j]: dot product of query pixel (h, w) with key pixel (i, j).
        scores = torch.einsum('bchw,bcij->bhwij', query, key) * (int(channels) ** (-0.5))
        # Softmax jointly over all H*W key positions.
        weights = F.softmax(scores.reshape(batch, height, width, height * width), dim=-1)
        weights = weights.reshape(batch, height, width, height, width)
        attended = torch.einsum('bhwij,bcij->bchw', weights, value)

        out = x + self.NIN_3(attended)
        return out / np.sqrt(2.) if self.skip_rescale else out
|
|
|
|
class ResBlock(nn.Module):
    """Pre-activation residual block with an optional time-embedding shift.

    Layout: GroupNorm -> SiLU -> 3x3 conv, optionally plus a projected
    time embedding, then GroupNorm -> SiLU -> dropout -> 3x3 conv. The input
    is routed through a ``NiN`` projection when channel counts differ.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        temb_dim: Width of the time embedding; ``None`` disables the
            ``dense0`` projection entirely.
        dropout: Dropout probability before the second convolution.
        skip_rescale: Divide the residual sum by sqrt(2) when True.
    """

    def __init__(self, in_ch, out_ch, temb_dim=None, dropout=0.1, skip_rescale=True):
        super().__init__()
        self.in_ch, self.out_ch = in_ch, out_ch
        self.skip_rescale = skip_rescale
        self.act = nn.functional.silu

        # Submodules are created in a fixed order; names double as state_dict keys.
        self.groupnorm0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.conv0 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        if temb_dim is not None:
            self.dense0 = nn.Linear(temb_dim, out_ch)
            nn.init.zeros_(self.dense0.bias)

        self.groupnorm1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.dropout0 = nn.Dropout(dropout)
        self.conv1 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

        if in_ch != out_ch:
            # 1x1 projection so the skip path matches the residual branch.
            self.nin = NiN(in_ch, out_ch)

    def forward(self, x, temb=None):
        """Run the block on ``x`` (B, in_ch, H, W); ``temb`` is (B, temb_dim) or None."""
        assert x.shape[1] == self.in_ch

        h = self.conv0(self.act(self.groupnorm0(x)))
        if temb is not None:
            # Broadcast the per-sample embedding over all spatial positions.
            h = h + self.dense0(self.act(temb))[:, :, None, None]
        h = self.conv1(self.dropout0(self.act(self.groupnorm1(h))))

        skip = self.nin(x) if h.shape[1] != self.in_ch else x
        assert skip.shape == h.shape

        total = skip + h
        return total / np.sqrt(2.) if self.skip_rescale else total
|
|
class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution.

    The input is padded by one row at the bottom and one column on the right
    (no padding on top/left) before an unpadded stride-2 convolution, giving
    an output of shape (B, C, H // 2, W // 2).
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Downsample ``x`` of shape (B, C, H, W)."""
        batch, channels, height, width = x.shape
        out = self.conv(F.pad(x, (0, 1, 0, 1)))
        assert out.shape == (batch, channels, height // 2, width // 2)
        return out
|
|
class Upsample(nn.Module):
    """Double the spatial resolution: nearest-neighbour resize, then a 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Upsample ``x`` of shape (B, C, H, W) to (B, C, 2H, 2W)."""
        batch, channels, height, width = x.shape
        target = (height * 2, width * 2)
        out = self.conv(F.interpolate(x, target, mode='nearest'))
        assert out.shape == (batch, channels, *target)
        return out
|
|
|
|
class UNet(nn.Module):
    """DDPM-style U-Net that outputs per-position logits over ``vocab_size`` bins.

    ``forward`` reshapes flattened inputs laid out as ``(c h w)`` into images
    and runs an input conv, a down/middle/up path with skip connections, and
    an output head. The ``2 * input_channels`` output maps (means and
    log-scales) are then turned into truncated discretized-logistic logits of
    shape ``(B, length, vocab_size)``.

    Args:
        config: Config exposing ``config.model.*`` and
            ``config.training.guidance``, plus ``config.data.num_classes`` when
            guidance is on. A plain ``dict`` is converted with
            ``omegaconf.OmegaConf.create``.
        vocab_size: Number of discrete bins per channel value; also the upper
            end of the range ``_center_data`` maps to [-1, 1].
    """

    def __init__(self, config, vocab_size=None):
        super().__init__()
        # isinstance (not an exact type check) so dict subclasses are converted too.
        if isinstance(config, dict):
            config = omegaconf.OmegaConf.create(config)
        self.ch = config.model.ch
        self.num_res_blocks = config.model.num_res_blocks
        self.num_scales = config.model.num_scales
        self.ch_mult = config.model.ch_mult
        assert self.num_scales == len(self.ch_mult)
        self.input_channels = config.model.input_channels
        # One mean map and one log-scale map per input channel.
        self.output_channels = 2 * config.model.input_channels
        self.scale_count_to_put_attn = config.model.scale_count_to_put_attn
        self.data_min_max = [0, vocab_size]
        self.dropout = config.model.dropout
        self.skip_rescale = config.model.skip_rescale
        self.time_conditioning = config.model.time_conditioning
        self.time_scale_factor = config.model.time_scale_factor
        self.time_embed_dim = config.model.time_embed_dim
        self.vocab_size = vocab_size

        # size: pixels per channel (H * W, with H == W);
        # length: flattened positions per example (input_channels * H * W).
        self.size = config.model.size
        self.length = config.model.length

        # Selects the min-of-two-formulas logits in _truncated_logistic_output.
        self.fix_logistic = config.model.fix_logistic

        self.act = nn.functional.silu

        # Two-layer MLP applied on top of the sinusoidal timestep embedding.
        if self.time_conditioning:
            temb_layers = [
                nn.Linear(self.time_embed_dim, self.time_embed_dim * 4),
                nn.Linear(self.time_embed_dim * 4, self.time_embed_dim * 4),
            ]
            for layer in temb_layers:
                nn.init.zeros_(layer.bias)
            self.temb_modules = nn.ModuleList(temb_layers)

        self.expanded_time_dim = 4 * self.time_embed_dim if self.time_conditioning else None

        self.input_conv = nn.Conv2d(
            in_channels=self.input_channels, out_channels=self.ch,
            kernel_size=3, padding=1
        )

        # h_cs records the channel count of every activation pushed onto the
        # skip stack during downsampling; the upsampling path pops them back.
        h_cs = [self.ch]
        in_ch = self.ch

        # Downsampling path. Module order must mirror _do_downsampling.
        downsampling_modules = []
        for scale_count in range(self.num_scales):
            for _ in range(self.num_res_blocks):
                out_ch = self.ch * self.ch_mult[scale_count]
                downsampling_modules.append(
                    ResBlock(in_ch, out_ch, temb_dim=self.expanded_time_dim,
                             dropout=self.dropout, skip_rescale=self.skip_rescale)
                )
                in_ch = out_ch
                h_cs.append(in_ch)
                if scale_count == self.scale_count_to_put_attn:
                    downsampling_modules.append(
                        AttnBlock(in_ch, skip_rescale=self.skip_rescale)
                    )

            if scale_count != self.num_scales - 1:
                downsampling_modules.append(Downsample(in_ch))
                h_cs.append(in_ch)

        self.downsampling_modules = nn.ModuleList(downsampling_modules)

        # Middle: ResBlock -> AttnBlock -> ResBlock. Order mirrors _do_middle.
        self.middle_modules = nn.ModuleList([
            ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim,
                     dropout=self.dropout, skip_rescale=self.skip_rescale),
            AttnBlock(in_ch, skip_rescale=self.skip_rescale),
            ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim,
                     dropout=self.dropout, skip_rescale=self.skip_rescale),
        ])

        # Upsampling path: each ResBlock consumes one skip activation (concatenated
        # along channels). Module order must mirror _do_upsampling.
        upsampling_modules = []
        for scale_count in reversed(range(self.num_scales)):
            for _ in range(self.num_res_blocks + 1):
                out_ch = self.ch * self.ch_mult[scale_count]
                upsampling_modules.append(
                    ResBlock(in_ch + h_cs.pop(),
                             out_ch,
                             temb_dim=self.expanded_time_dim,
                             dropout=self.dropout,
                             skip_rescale=self.skip_rescale
                             )
                )
                in_ch = out_ch

                if scale_count == self.scale_count_to_put_attn:
                    upsampling_modules.append(
                        AttnBlock(in_ch, skip_rescale=self.skip_rescale)
                    )
            if scale_count != 0:
                upsampling_modules.append(Upsample(in_ch))

        self.upsampling_modules = nn.ModuleList(upsampling_modules)

        # Every skip channel count pushed on the way down was consumed on the way up.
        assert len(h_cs) == 0

        # Output head: GroupNorm -> SiLU (applied in _do_output) -> 3x3 conv.
        self.output_modules = nn.ModuleList([
            nn.GroupNorm(min(in_ch // 4, 32), in_ch, eps=1e-6),
            nn.Conv2d(in_ch, self.output_channels, kernel_size=3, padding=1),
        ])

        if config.training.guidance:
            # NOTE(review): presumably the +1 reserves an extra (e.g. unconditional)
            # label for guidance — confirm against LabelEmbedder.
            self.cond_map = LabelEmbedder(
                config.data.num_classes + 1,
                self.time_embed_dim * 4)
        else:
            self.cond_map = None

    def _center_data(self, x):
        """Linearly map values in [data_min_max[0], data_min_max[1]] onto [-1, 1]."""
        out = (x - self.data_min_max[0]) / (self.data_min_max[1] - self.data_min_max[0])
        return 2 * out - 1

    def _time_embedding(self, timesteps):
        """Sinusoidal embedding + MLP of ``timesteps``; ``None`` when time conditioning is off."""
        if self.time_conditioning:
            temb = transformer_timestep_embedding(
                timesteps * self.time_scale_factor, self.time_embed_dim
            )
            temb = self.temb_modules[0](temb)
            temb = self.temb_modules[1](self.act(temb))
        else:
            temb = None
        return temb

    def _do_input_conv(self, h):
        """Apply the input conv and start the skip stack with its output."""
        h = self.input_conv(h)
        hs = [h]
        return h, hs

    def _do_downsampling(self, h, hs, temb):
        """Run the downsampling modules, pushing each skip activation onto ``hs``."""
        m_idx = 0
        for scale_count in range(self.num_scales):
            for _ in range(self.num_res_blocks):
                h = self.downsampling_modules[m_idx](h, temb)
                m_idx += 1
                if scale_count == self.scale_count_to_put_attn:
                    h = self.downsampling_modules[m_idx](h)
                    m_idx += 1
                hs.append(h)

            if scale_count != self.num_scales - 1:
                h = self.downsampling_modules[m_idx](h)
                hs.append(h)
                m_idx += 1

        assert m_idx == len(self.downsampling_modules)
        return h, hs

    def _do_middle(self, h, temb):
        """Run the ResBlock -> AttnBlock -> ResBlock bottleneck."""
        res_0, attn, res_1 = self.middle_modules
        h = res_0(h, temb)
        h = attn(h)
        h = res_1(h, temb)
        return h

    def _do_upsampling(self, h, hs, temb):
        """Run the upsampling modules, concatenating one popped skip per ResBlock."""
        m_idx = 0
        for scale_count in reversed(range(self.num_scales)):
            for _ in range(self.num_res_blocks + 1):
                h = self.upsampling_modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
                m_idx += 1

                if scale_count == self.scale_count_to_put_attn:
                    h = self.upsampling_modules[m_idx](h)
                    m_idx += 1

            if scale_count != 0:
                h = self.upsampling_modules[m_idx](h)
                m_idx += 1

        assert len(hs) == 0
        assert m_idx == len(self.upsampling_modules)
        return h

    def _do_output(self, h):
        """GroupNorm -> SiLU -> conv producing ``output_channels`` maps."""
        h = self.output_modules[0](h)
        h = self.act(h)
        h = self.output_modules[1](h)
        return h

    def _logistic_output_res(self, h, centered_x_in):
        """Turn the first half of ``h``'s channels into residual means.

        Args:
            h: Network output of shape (B, 2C, H, W).
            centered_x_in: Centered input of shape (B, C, H, W).

        Returns:
            A new tensor whose first C channels are
            ``tanh(centered_x_in + h[:, :C])`` and whose remaining channels are
            ``h[:, C:]`` unchanged. ``h`` itself is not modified.
        """
        C = h.shape[1] // 2
        means = torch.tanh(centered_x_in + h[:, :C])
        return torch.cat([means, h[:, C:]], dim=1)

    def _log_minus_exp(self, a, b, eps=1e-6):
        """
        Compute log (exp(a) - exp(b)) for (b<a)
        From https://arxiv.org/pdf/2107.03006.pdf
        """
        return a + torch.log1p(-torch.exp(b - a) + eps)

    def _truncated_logistic_output(self, net_out):
        """Convert (means, log-scales) maps into truncated discretized-logistic logits.

        Args:
            net_out: Tensor of shape (B, 2C, H, W) with C = ``self.input_channels``.
                Channels [0, C) are the logistic locations and channels
                [C, 2C) the log-scales.

        Returns:
            Logits of shape (B, ``self.length``, ``self.vocab_size``), positions
            flattened in ``(c h w)`` order.
        """
        B, D = net_out.shape[0], self.length
        C = self.input_channels
        S = self.vocab_size

        # Trailing singleton axis broadcasts against the S bins.
        mu = net_out[:, 0:C, :, :].unsqueeze(-1)
        log_scale = net_out[:, C:, :, :].unsqueeze(-1)

        inv_scale = torch.exp(- (log_scale - 2))

        # S equal-width bins tiling [-1, 1]. Built on net_out's device
        # (previously hard-coded to 'cuda', which broke CPU / non-default GPUs).
        bin_width = 2. / S
        bin_centers = torch.linspace(start=-1. + bin_width / 2,
                                     end=1. - bin_width / 2,
                                     steps=S,
                                     device=net_out.device).view(1, 1, 1, 1, S)

        sig_in_left = (bin_centers - bin_width / 2 - mu) * inv_scale
        bin_left_logcdf = F.logsigmoid(sig_in_left)
        sig_in_right = (bin_centers + bin_width / 2 - mu) * inv_scale
        bin_right_logcdf = F.logsigmoid(sig_in_right)

        # Two algebraically equal forms of log(CDF(right) - CDF(left)); the second
        # is written in terms of the complementary CDF.
        logits_1 = self._log_minus_exp(bin_right_logcdf, bin_left_logcdf)
        logits_2 = self._log_minus_exp(-sig_in_left + bin_left_logcdf, -sig_in_right + bin_right_logcdf)
        if self.fix_logistic:
            logits = torch.min(logits_1, logits_2)
        else:
            logits = logits_1

        return logits.view(B, D, S)

    def forward(self,
                x,
                timesteps=None,
                cond=None,
                x_emb=None,
                ):
        """Predict per-position logits.

        Args:
            x: Tensor of shape (B, input_channels * size) laid out as
                ``(c h w)``; values are expected in ``data_min_max`` range
                (presumably token ids — confirm with the caller).
            timesteps: 1-D tensor of shape (B,); required when time
                conditioning is enabled, ignored otherwise.
            cond: Optional (B,) labels fed to ``cond_map``.
            x_emb: Unused by this model.

        Returns:
            Logits of shape (B, length, vocab_size).

        Raises:
            ValueError: if ``cond`` is given but the model was built without
                a condition embedding layer.
        """
        img_size = int(np.sqrt(self.size))

        h = rearrange(x, "b (c h w) -> b c h w",
                      h=img_size, w=img_size, c=self.input_channels)
        h = self._center_data(h)
        centered_x_in = h

        temb = self._time_embedding(timesteps)
        if cond is not None:
            if self.cond_map is None:
                raise ValueError("Conditioning variable provided, "
                                 "but Model was not initialized "
                                 "with condition embedding layer.")
            else:
                assert cond.shape == (x.shape[0],)
                # NOTE(review): assumes time_conditioning is on; otherwise temb is None here.
                temb = temb + self.cond_map(cond)

        h, hs = self._do_input_conv(h)
        h, hs = self._do_downsampling(h, hs, temb)
        h = self._do_middle(h, temb)
        h = self._do_upsampling(h, hs, temb)
        h = self._do_output(h)

        h = self._logistic_output_res(h, centered_x_in)
        h = self._truncated_logistic_output(h)
        return h
|
|
|
|
class UNetConfig(transformers.PretrainedConfig):
    """Hugging Face configuration class for MDLM's U-Net backbone.

    Stores the hyper-parameters as plain attributes. List-valued arguments
    (``ch_mult``, ``data_min_max``) are copied, so configs never share a list
    object with each other or with the signature defaults.
    """
    model_type = "unet"

    def __init__(
            self,
            ch: int = 128,
            num_res_blocks: int = 2,
            num_scales: int = 4,
            ch_mult: list = [1, 2, 2, 2],
            input_channels: int = 3,
            output_channels: int = 3,
            scale_count_to_put_attn: int = 1,
            data_min_max: list = [0, 255],
            dropout: float = 0.1,
            skip_rescale: bool = True,
            time_conditioning: bool = True,
            time_scale_factor: float = 1000,
            time_embed_dim: int = 128,
            fix_logistic: bool = False,
            vocab_size: int = 256,
            size: int = 1024,
            guidance_classifier_free: bool = False,
            guidance_num_classes: int = -1,
            cond_dim: int = -1,
            length: int = 3072,
            **kwargs):
        super().__init__(**kwargs)
        self.ch = ch
        self.num_res_blocks = num_res_blocks
        self.num_scales = num_scales
        # Copy so in-place edits on one config cannot leak into the shared
        # default list (and thus into every later default-constructed config).
        self.ch_mult = None if ch_mult is None else list(ch_mult)
        self.input_channels = input_channels
        # NOTE(review): the ``output_channels`` argument is ignored and the
        # attribute is set to ``vocab_size`` instead. Preserved as-is — confirm intent.
        self.output_channels = vocab_size
        self.scale_count_to_put_attn = scale_count_to_put_attn
        self.data_min_max = None if data_min_max is None else list(data_min_max)
        self.dropout = dropout
        self.skip_rescale = skip_rescale
        self.time_conditioning = time_conditioning
        self.time_scale_factor = time_scale_factor
        self.time_embed_dim = time_embed_dim
        self.fix_logistic = fix_logistic

        self.vocab_size = vocab_size
        self.size = size
        self.guidance_classifier_free = guidance_classifier_free
        self.guidance_num_classes = guidance_num_classes
        self.cond_dim = cond_dim
        self.length = length
|
|