Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import math | |
class Unet(nn.Module):
    """Transformer-encoder network used either as a denoiser or a discriminator.

    The input sequence is linearly embedded, a conditioning token built from
    the diffusion timestep (optionally combined with a text embedding and a
    progress-indicator embedding) is prepended, and the result is run through
    a positional encoder and an ``nn.TransformerEncoder``. The conditioning
    token is dropped from the output again.

    Args:
        dim_model: Width of the transformer (embedding size).
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout_p: Dropout probability used by the positional encoders and
            the encoder layers.
        dim_input: Feature size of the last axis of ``x`` / ``joints_orig``.
        dim_output: Feature size produced by the output projection.
        free_p: Stored on the instance; not used by the code in this class.
        text_emb: When true, ``forward_`` additionally conditions on a text
            embedding and a progress indicator.
        device: Stored on the instance; not used by the code in this class.
        **kwargs: ``Disc=True`` switches the module to discriminator mode
            (adds the ``pred`` head and routes ``forward`` to
            ``forward_disc``). Other keys are ignored.
    """

    def __init__(
        self,
        dim_model,
        num_heads,
        num_layers,
        dropout_p,
        dim_input,
        dim_output,
        free_p=0.1,
        text_emb=True,
        device='cuda',
        **kwargs
    ):
        super().__init__()
        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model
        self.text_emb = text_emb
        self.dim_input = dim_input
        self.device = device
        # Discriminator mode is opt-in through the keyword argument ``Disc``.
        self.Disc = kwargs.get('Disc', False)

        # layers
        # NOTE: sub-modules are created in a fixed order so that seeded
        # initialisation and state-dict key order stay stable.
        self.free_p = free_p
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=5000
        )
        self.embedding_input = nn.Linear(dim_input, dim_model)
        self.embedding_original = nn.Linear(dim_input, dim_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=num_heads,
            dim_feedforward=dim_model * 4,
            dropout=dropout_p,
            activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )
        if self.Disc:
            # for discriminator: maps the pooled features to a probability
            self.pred = nn.Sequential(
                nn.Linear(dim_output, dim_output),
                nn.SiLU(inplace=False),
                nn.Linear(dim_output, 1),
                nn.Sigmoid(),
            )
        self.out = nn.Linear(dim_model, dim_output)
        # The timestep embedder looks up rows of the shared positional table.
        self.embed_timestep = TimestepEmbedder(self.dim_model, self.positional_encoder)
        if self.text_emb:
            # for embedding progress indicator
            print("text embedding is enabled!")
            self.positional_encoder_pi = PositionalEncoding(
                dim_model=dim_model, dropout_p=dropout_p, max_len=5000
            )
            self.embed_prog_ind = ProgIndEmbedder(self.dim_model, self.positional_encoder_pi)

    def forward_disc(self, x, timesteps):
        """Discriminator pass: return one sigmoid score per batch element.

        Args:
            x: Input sequence, batch-first with ``dim_input`` features on the
                last axis.
            timesteps: Integer indices used to look up the timestep embedding.

        Returns:
            Tensor with a trailing axis of size 1 holding the output of the
            ``pred`` head applied to the sequence-averaged features.
        """
        t_emb = self.embed_timestep(timesteps)  # t_emb refers to time embedding
        # The transformer is sequence-first, so swap batch and sequence axes.
        x, t_emb = x.permute(1, 0, 2), t_emb.permute(1, 0, 2)
        x = self.embedding_input(x) * math.sqrt(self.dim_model)
        x = torch.cat((t_emb, x), dim=0)
        x = self.positional_encoder(x)
        x = self.transformer(x)
        output = self.out(x)[1:]  # drop the prepended timestep token
        output = output.permute(1, 0, 2)
        output = output.mean(dim=1)  # average over the sequence axis
        output = self.pred(output)
        return output

    def forward_(self, x, timesteps, text_emb=None, prog_ind=None, joints_orig=None):
        """Denoiser pass: return a sequence shaped like the input sequence.

        Args:
            x: Input sequence, batch-first with ``dim_input`` features on the
                last axis.
            timesteps: Integer indices used to look up the timestep embedding.
            text_emb: Text embedding of shape ``(batchsize, dim_model)``.
                Required when the model was built with ``text_emb=True``.
            prog_ind: Progress indicator; it is multiplied by 100, rounded and
                used as an integer index. Required when the model was built
                with ``text_emb=True``.
            joints_orig: Original sequence with the same layout as ``x``.
                Always required.

        Returns:
            Tensor of shape ``(batchsize, seq_len, dim_output)``.

        Raises:
            ValueError: If a required input is ``None``.
            AssertionError: If ``text_emb`` does not have the expected shape.
        """
        # Fail early with a clear message instead of an AttributeError /
        # TypeError from deep inside the tensor operations below.
        if joints_orig is None:
            raise ValueError("joints_orig is required by forward_")
        if self.text_emb and (text_emb is None or prog_ind is None):
            raise ValueError(
                "text_emb and prog_ind are required when text embedding is enabled"
            )

        t_emb = self.embed_timestep(timesteps)  # t_emb refers to time embedding
        if self.text_emb:
            text_emb = text_emb.unsqueeze(1)  # batchsize, 1, 512
            assert text_emb.shape == (x.shape[0], 1, self.dim_model), \
                f'text_emb shape should be (batchsize, 1, {self.dim_model})'

        # The transformer is sequence-first, so swap batch and sequence axes.
        x, joints_orig, t_emb = x.permute(1, 0, 2), joints_orig.permute(1, 0, 2), t_emb.permute(1, 0, 2)
        x = self.embedding_input(x) * math.sqrt(self.dim_model)
        joints_orig = self.embedding_original(joints_orig) * math.sqrt(self.dim_model)
        # Average the two embedded sequences.
        x = (x + joints_orig) / 2.

        if not self.text_emb:
            x = torch.cat((t_emb, x), dim=0)  # (seq_len+1), batchsize, dim_model
        else:
            text_emb = text_emb.permute(1, 0, 2)
            # Turn the progress indicator into an integer lookup index.
            prog_ind = (prog_ind * 100).round().to(torch.int64)
            prog_ind_emb = self.embed_prog_ind(prog_ind).permute(1, 0, 2)
            # Fuse all conditioning signals into a single token; the text
            # embedding is scaled down by a factor of 10 before summing.
            t_emb = (t_emb + text_emb / 10.0 + prog_ind_emb) * math.sqrt(self.dim_model)
            x = torch.cat((t_emb, x), dim=0)

        x = self.positional_encoder(x)
        x = self.transformer(x)
        output = self.out(x)[1:]  # drop the prepended conditioning token
        output = output.permute(1, 0, 2)
        return output

    def forward(self, x, timesteps, text_emb=None, prog_ind=None, joints_orig=None):
        """Dispatch to the discriminator or the denoiser pass.

        In discriminator mode only ``x`` and ``timesteps`` are used; otherwise
        all arguments are forwarded to ``forward_``.
        """
        if self.Disc:
            return self.forward_disc(x, timesteps)
        return self.forward_(x, timesteps, text_emb, prog_ind, joints_orig)
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout.

    Modified version from:
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    A table of shape ``(max_len, 1, dim_model)`` is precomputed and stored as
    the non-trainable buffer ``pos_encoding``:

        PE(pos, 2i)     = sin(pos / 10000^(2i / dim_model))
        PE(pos, 2i + 1) = cos(pos / 10000^(2i / dim_model))

    Args:
        dim_model: Size of the encoding (last axis). Both even and odd sizes
            are supported.
        dropout_p: Dropout probability applied to the output of ``forward``.
        max_len: Number of positions in the table, i.e. the longest sequence
            that can be encoded.
    """

    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)

        # Encoding - from formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).reshape(-1, 1)  # 0, 1, 2, ...
        division_term = torch.exp(
            torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model
        )  # 1 / 10000^(2i / dim_model)
        angles = positions_list * division_term

        # Even columns get the sine terms.
        pos_encoding[:, 0::2] = torch.sin(angles)
        # Odd columns get the cosine terms. For an odd dim_model there is one
        # odd column fewer than even columns, so the last frequency is
        # dropped; for an even dim_model the slice is a no-op.
        pos_encoding[:, 1::2] = torch.cos(angles[:, : dim_model // 2])

        # Saving buffer (same as parameter without gradients needed);
        # reshaped to (max_len, 1, dim_model) so it broadcasts over the batch
        # axis of sequence-first inputs.
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """Add the encoding for the first ``token_embedding.size(0)`` positions.

        Args:
            token_embedding: Sequence-first tensor whose first axis indexes
                positions and whose last axis has size ``dim_model``.

        Returns:
            ``token_embedding`` plus the positional table rows, after dropout.
        """
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])
class TimestepEmbedder(nn.Module):
    """Map integer timesteps to learned embeddings.

    The timesteps index rows of the ``pos_encoding`` table owned by
    ``sequence_pos_encoder``; the selected rows are then passed through a
    two-layer MLP (Linear -> SiLU -> Linear) of width ``latent_dim``.
    """

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        # Input, hidden and output widths are all equal to latent_dim.
        width = self.latent_dim
        mlp_layers = [
            nn.Linear(width, width),
            nn.SiLU(inplace=False),
            nn.Linear(width, width),
        ]
        self.time_embed = nn.Sequential(*mlp_layers)

    def forward(self, timesteps):
        """Return the MLP output for the table rows selected by ``timesteps``."""
        table = self.sequence_pos_encoder.pos_encoding
        selected = table[timesteps]
        return self.time_embed(selected)
| # Structurally identical to TimestepEmbedder | |
class ProgIndEmbedder(nn.Module):
    """Map integer progress-indicator indices to learned embeddings.

    Same construction as ``TimestepEmbedder``: the indices select rows of the
    ``pos_encoding`` table owned by ``sequence_pos_encoder`` and the rows are
    fed through a two-layer MLP (Linear -> SiLU -> Linear) of width
    ``latent_dim``.
    """

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        # Input, hidden and output widths are all equal to latent_dim.
        width = self.latent_dim
        mlp_layers = [
            nn.Linear(width, width),
            nn.SiLU(inplace=False),
            nn.Linear(width, width),
        ]
        self.time_embed = nn.Sequential(*mlp_layers)

    def forward(self, timesteps):
        """Return the MLP output for the table rows selected by ``timesteps``."""
        table = self.sequence_pos_encoder.pos_encoding
        selected = table[timesteps]
        return self.time_embed(selected)