"""Transformer denoiser with timestep embedding and three masked attention branches."""
import math

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .nn import timestep_embedding


def dec2bin(xinp, bits):
    """Expand every integer of ``xinp`` into ``bits`` binary digits.

    A new trailing dimension of size ``bits`` is appended, most-significant
    bit first. Only the lowest ``bits`` bits are inspected, so higher bits of
    the input are dropped.

    :param xinp: integer tensor of any shape.
    :param bits: number of binary digits to extract per element.
    :return: float tensor of shape ``xinp.shape + (bits,)`` holding 0.0 / 1.0.
    """
    mask = 2 ** th.arange(bits - 1, -1, -1).to(xinp.device, xinp.dtype)
    return xinp.unsqueeze(-1).bitwise_and(mask).ne(0).float()


class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding followed by dropout."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        :param d_model: size of the embedding dimension.
        :param dropout: dropout probability applied after adding the encoding.
        :param max_len: number of positions for which the table is precomputed.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = th.arange(max_len).unsqueeze(1)
        div_term = th.exp(th.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = th.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = th.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the cosine table; for even d_model the slice is a no-op.
        pe[0, :, 1::2] = th.cos(angles)[:, : d_model // 2]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        x = x + self.pe[0:1, :x.size(1)]
        return self.dropout(x)


class FeedForward(nn.Module):
    """Two-layer position-wise feed-forward network."""

    def __init__(self, d_model, d_ff, dropout, activation):
        """
        :param d_model: input and output feature size.
        :param d_ff: hidden feature size.
        :param dropout: dropout probability applied after the activation.
        :param activation: callable applied to the hidden representation.
        """
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.activation = activation

    def forward(self, x):
        """Apply linear -> activation -> dropout -> linear to ``x``."""
        x = self.dropout(self.activation(self.linear_1(x)))
        x = self.linear_2(x)
        return x


def attention(q, k, v, d_k, mask=None, dropout=None):
    """Scaled dot-product attention.

    :param q: query tensor; the last two dims are (length, d_k).
    :param k: key tensor; the last two dims are (length, d_k).
    :param v: value tensor.
    :param d_k: per-head feature size used for the ``1/sqrt(d_k)`` scaling.
    :param mask: optional tensor; positions where ``mask == 1`` are suppressed
        (filled with -1e9 before the softmax). A dimension is inserted at
        index 1 so the mask broadcasts against the scores.
    :param dropout: optional module/callable applied to the attention weights.
    :return: attention-weighted combination of ``v``.
    """
    scores = th.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        mask = mask.unsqueeze(1)
        scores = scores.masked_fill(mask == 1, -1e9)
    scores = F.softmax(scores, dim=-1)
    if dropout is not None:
        scores = dropout(scores)
    output = th.matmul(scores, v)
    return output


class MultiHeadAttention(nn.Module):
    """Multi-head attention built on :func:`attention`."""

    def __init__(self, heads, d_model, dropout = 0.1):
        """
        :param heads: number of attention heads.
        :param d_model: model feature size; each head uses ``d_model // heads``.
        :param dropout: dropout probability applied to the attention weights.
        """
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Project q/k/v, attend per head, and merge the heads again.

        :param q: query tensor; the first dimension is the batch.
        :param k: key tensor.
        :param v: value tensor.
        :param mask: optional mask forwarded to :func:`attention`
            (entries equal to 1 are suppressed).
        :return: tensor of shape [batch, length, d_model].
        """
        bs = q.size(0)
        # Linear projections, then split the feature dim into h heads of d_k.
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # Transpose to [bs, h, seq_len, d_k] so attention runs per head.
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = attention(q, k, v, self.d_k, mask, self.dropout)
        # Concatenate the heads and apply the final linear layer.
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output


class EncoderLayer(nn.Module):
    """Pre-norm encoder layer summing three separately masked attentions."""

    def __init__(self, d_model, heads, dropout, activation):
        """
        :param d_model: feature size of the layer.
        :param heads: number of heads for each attention module.
        :param dropout: dropout probability for the residual branches and the
            feed-forward block.
        :param activation: activation used in the feed-forward block.
        """
        super().__init__()
        # NOTE(review): forward() appears to feed [N, S, C] tensors here while
        # nn.InstanceNorm1d documents an (N, C, L) layout -- confirm that the
        # resulting normalization axis is the intended one.
        self.norm_1 = nn.InstanceNorm1d(d_model)
        self.norm_2 = nn.InstanceNorm1d(d_model)
        # NOTE(review): the attention modules keep MultiHeadAttention's default
        # dropout (0.1); the `dropout` argument is not forwarded to them.
        self.self_attn = MultiHeadAttention(heads, d_model)
        self.door_attn = MultiHeadAttention(heads, d_model)
        self.gen_attn = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model, d_model*2, dropout, activation)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, door_mask, self_mask, gen_mask):
        """Run the three attention branches and the feed-forward block.

        :param x: input features.
        :param door_mask: mask for the door attention branch.
        :param self_mask: mask for the self attention branch.
        :param gen_mask: mask for the general attention branch; it must
            contain both a 0 and a 1 entry (checked by assertion).
        :return: tensor with the same shape as ``x``.
        """
        assert (gen_mask.max()==1 and gen_mask.min()==0), f"{gen_mask.max()}, {gen_mask.min()}"
        x2 = self.norm_1(x)
        # All three branches read the same normalized input and are added to
        # the residual stream together.
        x = x + self.dropout(self.door_attn(x2, x2, x2, door_mask)) \
            + self.dropout(self.self_attn(x2, x2, x2, self_mask)) \
            + self.dropout(self.gen_attn(x2, x2, x2, gen_mask))
        x2 = self.norm_2(x)
        x = x + self.dropout(self.ff(x2))
        return x


class TransformerModel(nn.Module):
    """
    The full Transformer model with timestep embedding.
    """

    def __init__(
        self,
        in_channels,
        condition_channels,
        model_channels,
        out_channels,
        dataset,
        use_checkpoint,
        use_unet,
        analog_bit,
    ):
        """
        :param in_channels: feature size of the (possibly expanded) input.
        :param condition_channels: feature size of the concatenated conditions;
            when it is not positive no condition embedding is added in forward.
        :param model_channels: hidden feature size of the transformer.
        :param out_channels: feature size of the decimal output head.
        :param dataset: not used in this class.
        :param use_checkpoint: stored on the instance; not otherwise used here.
        :param use_unet: if true, a ``UNet`` submodule is constructed.
        :param analog_bit: if false, the additional binary output head is
            built and evaluated in forward.
        """
        super().__init__()
        self.in_channels = in_channels
        self.condition_channels = condition_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.time_channels = model_channels
        self.use_checkpoint = use_checkpoint
        self.analog_bit = analog_bit
        self.use_unet = use_unet
        self.num_layers = 4

        self.activation = nn.ReLU()

        self.time_embed = nn.Sequential(
            nn.Linear(self.model_channels, self.model_channels),
            nn.SiLU(),
            nn.Linear(self.model_channels, self.time_channels),
        )
        self.input_emb = nn.Linear(self.in_channels, self.model_channels)
        self.condition_emb = nn.Linear(self.condition_channels, self.model_channels)

        if use_unet:
            # NOTE(review): `UNet` is not imported or defined in the visible
            # part of this file; use_unet=True fails unless it is provided
            # elsewhere -- confirm.
            self.unet = UNet(self.model_channels, 1)

        self.transformer_layers = nn.ModuleList(
            [EncoderLayer(self.model_channels, 4, 0.1, self.activation)
             for _ in range(self.num_layers)]
        )

        self.output_linear1 = nn.Linear(self.model_channels, self.model_channels)
        self.output_linear2 = nn.Linear(self.model_channels, self.model_channels//2)
        self.output_linear3 = nn.Linear(self.model_channels//2, self.out_channels)

        if not self.analog_bit:
            # 162 matches what forward() concatenates ahead of the condition
            # embedding: out_bin_start and its 16*9 = 144 bit features.
            self.output_linear_bin1 = nn.Linear(162+self.model_channels, self.model_channels)
            self.output_linear_bin2 = EncoderLayer(self.model_channels, 1, 0.1, self.activation)
            self.output_linear_bin3 = EncoderLayer(self.model_channels, 1, 0.1, self.activation)
            self.output_linear_bin4 = nn.Linear(self.model_channels, 16)

        print(f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")

    def expand_points(self, points, connections):
        """Augment each point with 8 samples towards its connected point.

        For every point p1 and the point p5 selected through
        ``connections[:, :, 1]``, the positions at fractions 0, 1/8, ..., 1 of
        the way from p1 to p5 are concatenated along the last dimension.

        :param points: tensor indexed as [batch, point, feature].
        :param connections: tensor whose ``[:, :, 1]`` entries are used as
            indices into the point dimension (presumably the next corner of
            the polygon -- verify against the caller).
        :return: detached tensor whose last dimension is 9x that of ``points``.
        """
        def average_points(point1, point2):
            points_new = (point1+point2)/2
            return points_new

        p1 = points
        p1 = p1.view([p1.shape[0], p1.shape[1], 2, -1])
        p5 = points[th.arange(points.shape[0])[:, None], connections[:,:,1].long()]
        p5 = p5.view([p5.shape[0], p5.shape[1], 2, -1])
        # Repeated midpoints give the 1/8-spaced samples between p1 and p5.
        p3 = average_points(p1, p5)
        p2 = average_points(p1, p3)
        p4 = average_points(p3, p5)
        p1_5 = average_points(p1, p2)
        p2_5 = average_points(p2, p3)
        p3_5 = average_points(p3, p4)
        p4_5 = average_points(p4, p5)
        ordered = (p1, p1_5, p2, p2_5, p3, p3_5, p4, p4_5, p5)
        points_new = th.cat(tuple(p.view_as(points) for p in ordered), 2)
        return points_new.detach()

    def create_image(self, points, connections, room_indices, img_size=256, res=200):
        """Rasterize the segments between connected points into an image.

        :param points: point coordinates; they are mapped with
            ``(points + 1) * (img_size // 2)`` and clamped to the image, so
            they are presumably normalized to [-1, 1] -- confirm.
        :param connections: tensor whose ``[:, :, 1]`` entries index the point
            each segment ends at.
        :param room_indices: per-point values written into the touched pixels.
        :param img_size: side length of the square output image.
        :param res: number of samples taken along every segment.
        :return: detached tensor of shape [batch, 1, img_size, img_size].
        """
        img = th.zeros((points.shape[0], 1, img_size, img_size), device=points.device)
        # `points` is rebound to a new tensor here, so the in-place clamping
        # below does not modify the caller's tensor.
        points = (points+1)*(img_size//2)
        points[points>=img_size] = img_size-1
        points[points<0] = 0
        p1 = points
        p2 = points[th.arange(points.shape[0])[:, None], connections[:,:,1].long()]

        slope = (p2[:,:,1]-p1[:,:,1])/((p2[:,:,0]-p1[:,:,0]))
        slope[slope.isnan()] = 0  # 0/0: start and end point coincide
        slope[slope.isinf()] = 1  # vertical segment; paired with x_inc below

        m = th.linspace(0, 1, res, device=points.device)
        new_shape = [p2.shape[0], res, p2.shape[1], p2.shape[2]]
        new_p2 = p2.unsqueeze(1).expand(new_shape)
        new_p1 = p1.unsqueeze(1).expand(new_shape)
        new_room_indices = room_indices.unsqueeze(1).expand([p2.shape[0], res, p2.shape[1], 1])

        inc = new_p2 - new_p1
        xs = m.view(1,-1,1) * inc[:,:,:,0]
        xs = xs + new_p1[:,:,:,0]
        xs = xs.long()
        # For vertical segments step along y directly (slope was set to 1).
        x_inc = th.where(inc[:,:,:,0]==0, inc[:,:,:,1], inc[:,:,:,0])
        x_inc = m.view(1,-1,1) * x_inc
        ys = x_inc * slope.unsqueeze(1) + new_p1[:,:,:,1]
        ys = ys.long()

        img[th.arange(xs.shape[0])[:, None], :, xs.view(img.shape[0], -1), ys.view(img.shape[0], -1)] = new_room_indices.reshape(img.shape[0], -1, 1).float()
        return img.detach()

    def forward(self, x, timesteps, xtalpha, epsalpha, is_syn=False, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x S] Tensor of inputs (permuted to [N x S x C]
            internally).
        :param timesteps: a 1-D batch of timesteps.
        :param xtalpha: factor multiplied with the (expanded) input when the
            binary head reconstructs its starting estimate; only used when
            ``analog_bit`` is false.
        :param epsalpha: factor multiplied with the decimal prediction in the
            same reconstruction; only used when ``analog_bit`` is false.
        :param is_syn: if true, every key read from ``kwargs`` is prefixed
            with ``'syn_'``.
        :param kwargs: must provide ``door_mask``, ``self_mask`` and
            ``gen_mask``; additionally ``room_types``, ``corner_indices`` and
            ``room_indices`` when ``condition_channels > 0``, and
            ``connections`` when ``analog_bit`` is false (all with the
            optional ``syn_`` prefix).
        :return: a tuple ``(out_dec, out_bin)`` of [N x C x S] Tensors;
            ``out_bin`` is ``None`` when ``analog_bit`` is true.
        """
        prefix = 'syn_' if is_syn else ''
        door_mask = kwargs[f'{prefix}door_mask']
        self_mask = kwargs[f'{prefix}self_mask']
        gen_mask = kwargs[f'{prefix}gen_mask']

        x = x.permute([0, 2, 1]).float()  # -> convert [N x C x S] to [N x S x C]

        if not self.analog_bit:
            x = self.expand_points(x, kwargs[f'{prefix}connections'])

        # Different input embeddings (Input, Time, Conditions)
        time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        time_emb = time_emb.unsqueeze(1)
        input_emb = self.input_emb(x)
        if self.condition_channels > 0:
            cond_keys = (f'{prefix}room_types', f'{prefix}corner_indices', f'{prefix}room_indices')
            cond = th.cat([kwargs[key] for key in cond_keys], 2)
            cond_emb = self.condition_emb(cond.float())
        else:
            # Without conditions, use a zero embedding so the sum below and the
            # binary head's concatenation stay well defined (previously this
            # path failed because cond_emb was never assigned).
            cond_emb = th.zeros_like(input_emb)

        # time_emb is [N, 1, C] and broadcasts over the sequence dimension.
        out = input_emb + cond_emb + time_emb
        for layer in self.transformer_layers:
            out = layer(out, door_mask, self_mask, gen_mask)

        out_dec = self.output_linear1(out)
        out_dec = self.activation(out_dec)
        out_dec = self.output_linear2(out_dec)
        out_dec = self.output_linear3(out_dec)

        if not self.analog_bit:
            out_bin_start = x*xtalpha.repeat([1,1,9]) - out_dec.repeat([1,1,9]) * epsalpha.repeat([1,1,9])
            out_bin = (out_bin_start/2 + 0.5)  # -> [0,1]
            out_bin = out_bin * 256  # -> [0, 256]
            # Only the low 8 bits are kept, so a rounded value of 256 encodes
            # as all zeros.
            out_bin = dec2bin(out_bin.round().int(), 8)
            out_bin_inp = out_bin.reshape([x.shape[0], x.shape[1], 16*9])
            out_bin_inp[out_bin_inp==0] = -1  # bits as {-1, +1} instead of {0, 1}

            out_bin = th.cat((out_bin_start, out_bin_inp, cond_emb), 2)
            out_bin = self.activation(self.output_linear_bin1(out_bin))
            out_bin = self.output_linear_bin2(out_bin, door_mask, self_mask, gen_mask)
            out_bin = self.output_linear_bin3(out_bin, door_mask, self_mask, gen_mask)
            out_bin = self.output_linear_bin4(out_bin)

            out_bin = out_bin.permute([0, 2, 1])  # -> convert back [N x S x C] to [N x C x S]

        out_dec = out_dec.permute([0, 2, 1])  # -> convert back [N x S x C] to [N x C x S]

        if not self.analog_bit:
            return out_dec, out_bin
        else:
            return out_dec, None