Spaces:
Runtime error
Runtime error
| import math | |
| import torch as th | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .nn import timestep_embedding | |
def dec2bin(xinp, bits):
    """Expand integer values into fixed-width binary digits.

    Args:
        xinp: Integer tensor of any shape.
        bits: Number of binary digits produced per value.

    Returns:
        Float tensor with one extra trailing axis of length ``bits`` holding
        0.0 / 1.0 digits, most significant bit first. Bits above position
        ``bits - 1`` are ignored.
    """
    # Exponents bits-1 ... 0, moved to the input's device and dtype so the
    # bitwise AND below operates on matching tensors.
    exponents = th.arange(bits - 1, -1, -1).to(xinp.device, xinp.dtype)
    bit_masks = 2 ** exponents
    selected = xinp.unsqueeze(-1).bitwise_and(bit_masks)
    return selected.ne(0).float()
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    The table is precomputed once for ``max_len`` positions and stored as the
    non-trainable buffer ``pe`` of shape ``[1, max_len, d_model]``. Even
    feature columns hold sines and odd feature columns hold cosines.

    Args:
        d_model: Size of the embedding (last) dimension. May be odd.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed; inputs longer than this
            cannot be broadcast against the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = th.arange(max_len).unsqueeze(1)
        div_term = th.exp(th.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = th.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = th.sin(angles)
        # For an odd d_model there is one fewer odd column than even columns,
        # so trim the angles to the number of cosine slots actually available.
        pe[0, :, 1::2] = th.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``seq_len`` positions.

        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        x = x + self.pe[0:1, :x.size(1)]
        return self.dropout(x)
class FeedForward(nn.Module):
    """Two-layer position-wise MLP: linear -> activation -> dropout -> linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size (no default; the caller chooses it).
        dropout: Dropout probability applied to the hidden activations.
        activation: Callable applied to the output of the first linear layer.
    """

    def __init__(self, d_model, d_ff, dropout, activation):
        super().__init__()
        # Sub-modules are created in this order so parameter initialisation
        # and state_dict key order stay stable.
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.activation = activation

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        hidden = self.linear_1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)
def attention(q, k, v, d_k, mask=None, dropout=None):
    """Scaled dot-product attention.

    Args:
        q, k, v: Query, key and value tensors; the last two axes are
            (sequence, features). The caller in this file passes
            [batch, heads, seq, d_k].
        d_k: Feature size used to scale the logits by ``1 / sqrt(d_k)``.
        mask: Optional tensor; an axis is inserted at dim 1 before it is
            broadcast against the logits. Positions where the mask equals 1
            are suppressed.
        dropout: Optional callable applied to the attention weights.

    Returns:
        The attention-weighted combination of ``v``.
    """
    logits = th.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Entries equal to 1 receive a large negative logit, so they get
        # (numerically) zero weight after the softmax.
        blocked = mask.unsqueeze(1) == 1
        logits = logits.masked_fill(blocked, -1e9)

    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)

    return th.matmul(weights, v)
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on the module-level ``attention`` function.

    Args:
        heads: Number of attention heads.
        d_model: Feature size of the inputs and outputs. Each head works on
            ``d_model // heads`` features.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, heads, d_model, dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        # Creation order (q, v, k, dropout, out) is kept so parameter
        # initialisation and state_dict key order stay stable.
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """Reshape [bs, seq, d_model] into [bs, heads, seq, d_k]."""
        per_head = projected.view(batch_size, -1, self.h, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Project q/k/v, attend per head, then merge heads and project out.

        Args:
            q, k, v: Tensors whose first axis is the batch and whose last
                axis has size ``d_model``.
            mask: Optional mask forwarded unchanged to ``attention``.

        Returns:
            Tensor of shape [bs, seq, d_model].
        """
        batch_size = q.size(0)

        key = self._split_heads(self.k_linear(k), batch_size)
        query = self._split_heads(self.q_linear(q), batch_size)
        value = self._split_heads(self.v_linear(v), batch_size)

        attended = attention(query, key, value, self.d_k, mask, self.dropout)

        # Bring the head axis back next to the features and flatten both
        # into a single d_model axis before the output projection.
        merged = attended.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out(merged)
class EncoderLayer(nn.Module):
    """Pre-norm encoder block with three parallel masked attention branches.

    The normalised input is fed to three independent self-attention modules
    (door, self, gen), each with its own mask. Their outputs are summed into
    the residual stream, followed by a feed-forward sub-layer with a hidden
    size of ``2 * d_model``.

    Args:
        d_model: Feature size of the token embeddings.
        heads: Number of heads for each attention module.
        dropout: Dropout probability for the residual branches and the
            feed-forward hidden layer.
        activation: Activation callable used inside the feed-forward layer.
    """

    def __init__(self, d_model, heads, dropout, activation):
        super().__init__()
        # NOTE(review): InstanceNorm1d is documented for (N, C, L) input,
        # while the model in this file feeds [N, S, C] tensors -- confirm the
        # intended normalisation axis.
        self.norm_1 = nn.InstanceNorm1d(d_model)
        self.norm_2 = nn.InstanceNorm1d(d_model)
        # The attention modules are built without a dropout argument, so they
        # use MultiHeadAttention's own default rather than ``dropout``.
        self.self_attn = MultiHeadAttention(heads, d_model)
        self.door_attn = MultiHeadAttention(heads, d_model)
        self.gen_attn = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model, d_model*2, dropout, activation)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, door_mask, self_mask, gen_mask):
        """Run the attention and feed-forward sub-layers with residuals.

        Args:
            x: Input tensor; also used as query, key and value after
                normalisation.
            door_mask, self_mask, gen_mask: Masks passed to the matching
                attention branch. ``gen_mask`` must contain both a 0 and a 1
                and nothing outside that range, otherwise the assertion
                below fails.

        Returns:
            Tensor with the same shape as ``x``.
        """
        assert (gen_mask.max()==1 and gen_mask.min()==0), f"{gen_mask.max()}, {gen_mask.min()}"

        normed = self.norm_1(x)
        # Branches are evaluated in door -> self -> gen order so the dropout
        # random stream is consumed in a fixed sequence.
        door_branch = self.dropout(self.door_attn(normed, normed, normed, door_mask))
        self_branch = self.dropout(self.self_attn(normed, normed, normed, self_mask))
        gen_branch = self.dropout(self.gen_attn(normed, normed, normed, gen_mask))
        x = x + door_branch + self_branch + gen_branch

        renormed = self.norm_2(x)
        return x + self.dropout(self.ff(renormed))
class TransformerModel(nn.Module):
    """
    The full Transformer model with timestep embedding.

    Each token is embedded from its input features, an optional condition
    vector and a timestep embedding, passed through ``num_layers`` encoder
    layers, and decoded by a three-layer MLP. When ``analog_bit`` is false a
    second head additionally predicts 16 values per token from a binary
    encoding of the reconstructed coordinates.

    Args:
        in_channels: Feature size per token fed to the input embedding. With
            ``analog_bit`` false the tokens are first expanded 9x by
            ``expand_points``, so this must match the expanded width.
        condition_channels: Width of the concatenated condition features;
            ``0`` disables conditioning.
        model_channels: Hidden width of the transformer.
        out_channels: Feature size per token of the decimal output.
        dataset: Accepted for interface compatibility; not used in this class.
        use_checkpoint: Stored on the instance; not otherwise used here.
        use_unet: If true, a ``UNet`` is instantiated as ``self.unet``.
        analog_bit: If false, the binary refinement head is built and used.
    """

    def __init__(
        self,
        in_channels,
        condition_channels,
        model_channels,
        out_channels,
        dataset,
        use_checkpoint,
        use_unet,
        analog_bit,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.condition_channels = condition_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.time_channels = model_channels
        self.use_checkpoint = use_checkpoint
        self.analog_bit = analog_bit
        self.use_unet = use_unet
        self.num_layers = 4

        self.activation = nn.ReLU()

        self.time_embed = nn.Sequential(
            nn.Linear(self.model_channels, self.model_channels),
            nn.SiLU(),
            nn.Linear(self.model_channels, self.time_channels),
        )
        self.input_emb = nn.Linear(self.in_channels, self.model_channels)
        self.condition_emb = nn.Linear(self.condition_channels, self.model_channels)

        if use_unet:
            # NOTE(review): ``UNet`` is neither defined nor imported in the
            # visible part of this file -- confirm it is in scope before
            # enabling ``use_unet``.
            self.unet = UNet(self.model_channels, 1)

        self.transformer_layers = nn.ModuleList(
            [EncoderLayer(self.model_channels, 4, 0.1, self.activation) for _ in range(self.num_layers)]
        )

        self.output_linear1 = nn.Linear(self.model_channels, self.model_channels)
        self.output_linear2 = nn.Linear(self.model_channels, self.model_channels//2)
        self.output_linear3 = nn.Linear(self.model_channels//2, self.out_channels)

        if not self.analog_bit:
            # 162 = 18 reconstructed values + 16*9 binary digits per token;
            # the condition embedding contributes the model_channels part.
            self.output_linear_bin1 = nn.Linear(162+self.model_channels, self.model_channels)
            self.output_linear_bin2 = EncoderLayer(self.model_channels, 1, 0.1, self.activation)
            self.output_linear_bin3 = EncoderLayer(self.model_channels, 1, 0.1, self.activation)
            self.output_linear_bin4 = nn.Linear(self.model_channels, 16)

        print(f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")

    def expand_points(self, points, connections):
        """Append eight evenly spaced points along each point's outgoing edge.

        Args:
            points: Tensor of shape [N, S, C] with an even ``C``.
            connections: Tensor of shape [N, S, >=2]; ``connections[:, :, 1]``
                holds, per point, the index (along S) of the point it
                connects to.

        Returns:
            Detached tensor of shape [N, S, 9 * C]: the point itself, seven
            interpolated points at 1/8 steps, and the connected point.
        """
        def average_points(point1, point2):
            return (point1 + point2) / 2

        batch_index = th.arange(points.shape[0])[:, None]

        p1 = points.view([points.shape[0], points.shape[1], 2, -1])
        p5 = points[batch_index, connections[:, :, 1].long()]
        p5 = p5.view([p5.shape[0], p5.shape[1], 2, -1])

        # Repeated bisection: p3 is the midpoint, p2/p4 the quarter points,
        # and the *_5 values the remaining eighth points.
        p3 = average_points(p1, p5)
        p2 = average_points(p1, p3)
        p4 = average_points(p3, p5)
        p1_5 = average_points(p1, p2)
        p2_5 = average_points(p2, p3)
        p3_5 = average_points(p3, p4)
        p4_5 = average_points(p4, p5)

        ordered = (p1, p1_5, p2, p2_5, p3, p3_5, p4, p4_5, p5)
        points_new = th.cat([p.view_as(points) for p in ordered], 2)
        return points_new.detach()

    def create_image(self, points, connections, room_indices, img_size=256, res=200):
        """Rasterise the edges between connected points into an image.

        Args:
            points: Tensor of shape [N, S, 2]; coordinates are mapped from
                [-1, 1] to pixel space and clamped to the image bounds.
            connections: Tensor whose ``[:, :, 1]`` slice gives the index of
                the point each point connects to.
            room_indices: Tensor of shape [N, S, 1]; the value written to the
                pixels of each point's edge.
            img_size: Height and width of the square output image.
            res: Number of samples taken along every edge.

        Returns:
            Detached tensor of shape [N, 1, img_size, img_size].
        """
        img = th.zeros((points.shape[0], 1, img_size, img_size), device=points.device)

        # Map to pixel coordinates; the arithmetic yields a new tensor, so the
        # in-place clamping below does not modify the caller's ``points``.
        points = (points+1)*(img_size//2)
        points[points>=img_size] = img_size-1
        points[points<0] = 0

        p1 = points
        p2 = points[th.arange(points.shape[0])[:, None], connections[:,:,1].long()]

        slope = (p2[:,:,1]-p1[:,:,1])/((p2[:,:,0]-p1[:,:,0]))
        # 0/0 (identical points) -> 0; division by zero (vertical edge) -> 1,
        # in which case the y-delta is used as the increment further down.
        slope[slope.isnan()] = 0
        slope[slope.isinf()] = 1

        m = th.linspace(0, 1, res, device=points.device)
        new_shape = [p2.shape[0], res, p2.shape[1], p2.shape[2]]
        new_p2 = p2.unsqueeze(1).expand(new_shape)
        new_p1 = p1.unsqueeze(1).expand(new_shape)
        new_room_indices = room_indices.unsqueeze(1).expand([p2.shape[0], res, p2.shape[1], 1])

        inc = new_p2 - new_p1
        xs = m.view(1,-1,1) * inc[:,:,:,0]
        xs = xs + new_p1[:,:,:,0]
        xs = xs.long()

        x_inc = th.where(inc[:,:,:,0]==0, inc[:,:,:,1], inc[:,:,:,0])
        x_inc = m.view(1,-1,1) * x_inc
        ys = x_inc * slope.unsqueeze(1) + new_p1[:,:,:,1]
        ys = ys.long()

        img[th.arange(xs.shape[0])[:, None], :, xs.view(img.shape[0], -1), ys.view(img.shape[0], -1)] = new_room_indices.reshape(img.shape[0], -1, 1).float()
        return img.detach()

    def forward(self, x, timesteps, xtalpha, epsalpha, is_syn=False, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x S] Tensor of inputs (permuted to [N x S x C]
            internally).
        :param timesteps: a 1-D batch of timesteps.
        :param xtalpha: coefficient tensor multiplied with the (expanded)
            input in the binary head; only used when ``analog_bit`` is false.
            Presumably a diffusion schedule term -- confirm with the caller.
        :param epsalpha: coefficient tensor multiplied with the decimal
            prediction in the binary head; only used when ``analog_bit`` is
            false. Presumably a diffusion schedule term -- confirm with the
            caller.
        :param is_syn: if true, every key read from ``kwargs`` is prefixed
            with ``'syn_'``.
        :param kwargs: must provide ``door_mask``, ``self_mask`` and
            ``gen_mask``; ``connections`` when ``analog_bit`` is false; and
            ``room_types``, ``corner_indices`` and ``room_indices`` when
            ``condition_channels > 0``.
        :return: a tuple ``(out_dec, out_bin)`` of [N x C x S] Tensors;
            ``out_bin`` is ``None`` when ``analog_bit`` is true.
        """
        prefix = 'syn_' if is_syn else ''
        door_mask = kwargs[f'{prefix}door_mask']
        self_mask = kwargs[f'{prefix}self_mask']
        gen_mask = kwargs[f'{prefix}gen_mask']

        x = x.permute([0, 2, 1]).float()  # -> convert [N x C x S] to [N x S x C]
        if not self.analog_bit:
            x = self.expand_points(x, kwargs[f'{prefix}connections'])

        # Different input embeddings (Input, Time, Conditions)
        time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        time_emb = time_emb.unsqueeze(1)
        input_emb = self.input_emb(x)

        if self.condition_channels > 0:
            cond_keys = (f'{prefix}room_types', f'{prefix}corner_indices', f'{prefix}room_indices')
            cond = th.cat([kwargs[key] for key in cond_keys], 2)
            cond_emb = self.condition_emb(cond.float())
        else:
            # Without conditions the embedding contributes nothing; a zero
            # tensor keeps the sum below and the binary-head concatenation
            # well-defined instead of failing on an unbound ``cond_emb``.
            cond_emb = th.zeros_like(input_emb)

        out = input_emb + cond_emb + time_emb.repeat((1, input_emb.shape[1], 1))
        for layer in self.transformer_layers:
            out = layer(out, door_mask, self_mask, gen_mask)

        out_dec = self.output_linear1(out)
        out_dec = self.activation(out_dec)
        out_dec = self.output_linear2(out_dec)
        out_dec = self.output_linear3(out_dec)

        out_bin = None
        if not self.analog_bit:
            # Combine the expanded input and the decimal prediction, each
            # scaled by its coefficient (tiled 9x to match the expanded width).
            out_bin_start = x*xtalpha.repeat([1,1,9]) - out_dec.repeat([1,1,9]) * epsalpha.repeat([1,1,9])
            out_bin = (out_bin_start/2 + 0.5)  # -> [0,1]
            out_bin = out_bin * 256  # -> [0, 256]
            out_bin = dec2bin(out_bin.round().int(), 8)
            out_bin_inp = out_bin.reshape([x.shape[0], x.shape[1], 16*9])
            # Re-code the binary digits from {0, 1} to {-1, 1}.
            out_bin_inp[out_bin_inp==0] = -1

            out_bin = th.cat((out_bin_start, out_bin_inp, cond_emb), 2)
            out_bin = self.activation(self.output_linear_bin1(out_bin))
            out_bin = self.output_linear_bin2(out_bin, door_mask, self_mask, gen_mask)
            out_bin = self.output_linear_bin3(out_bin, door_mask, self_mask, gen_mask)
            out_bin = self.output_linear_bin4(out_bin)
            out_bin = out_bin.permute([0, 2, 1])  # -> convert back [N x S x C] to [N x C x S]

        out_dec = out_dec.permute([0, 2, 1])  # -> convert back [N x S x C] to [N x C x S]
        return out_dec, out_bin